# Smoke-test ScaledDotProductAttention on random inputs.
# Layout of every tensor: (batch_size, num_heads, seq_len, depth).
batch_size, num_heads, seq_len, depth = 2, 8, 5, 64

Q = torch.rand((batch_size, num_heads, seq_len, depth))  # queries: (batch, heads, seq_len_q, depth)
K = torch.rand((batch_size, num_heads, seq_len, depth))  # keys:    (batch, heads, seq_len_k, depth)
V = torch.rand((batch_size, num_heads, seq_len, depth))  # values:  (batch, heads, seq_len_v, depth)

# NOTE(review): `attn_mask` is not defined in this chunk — presumably created
# earlier in the file; confirm it exists before this point.
temp_out, temp_att = ScaledDotProductAttention()(Q, K, V, attn_mask)
print(temp_out.shape)  # expected: torch.Size([2, 8, 5, 64])
print(temp_att.shape)  # expected: torch.Size([2, 8, 5, 5]) — one weight per (query, key) pair